"""
Phase 1: Data Assembly — Sectoral Savings Decomposition Panel
==============================================================
Merges full_panel.csv with WDI savings/investment indicators to
test whether Z predicts household vs corporate vs government
savings separately — the corporate savings glut hypothesis.

Note: Full sectoral financial balances (OECD) are limited to ~38
countries. We supplement with WDI gross savings, gross national
savings, and investment to construct residual corporate savings.

Output: sectoral_savings/data/processed/sectoral_panel.csv
Tables: summary_statistics.md
"""

import sys
from pathlib import Path
import time

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Directory layout. Everything hangs off the sectoral_savings project folder;
# the sibling projects (multilateral, fiscal_dominance) live next to it under
# the shared root.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/sectoral_savings")
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR.joinpath("multilateral")
FISCAL_DIR = ROOT_DIR.joinpath("fiscal_dominance")
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
RAW_DIR = PROJECT_DIR.joinpath("data", "raw")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# Make sure this project's own data/output folders exist (done at import time).
for d in (PROCESSED_DIR, RAW_DIR, TABLES_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Put multilateral/src first on the module search path.
sys.path.insert(0, str(MULTILATERAL_DIR.joinpath("src")))

# Three-letter country codes of the 38 OECD members, matched against the
# panel's `iso3` column to build the `oecd` dummy.
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()


def download_wdi_indicator(indicator, name, countries, max_retries=3):
    """Download one WDI indicator from the World Bank API (v2, JSON).

    Countries are requested in chunks of 15 codes for the years 1970-2024.
    Each chunk is attempted up to ``max_retries`` times; a chunk that still
    fails is reported and skipped, so the result can cover only a subset of
    the requested countries.

    Args:
        indicator: WDI indicator code, e.g. ``'NY.GNS.ICTR.ZS'``.
        name: Column name to use for the indicator values.
        countries: Sequence of country codes to request.
        max_retries: Maximum number of attempts per chunk.

    Returns:
        DataFrame with columns ``iso3``, ``year`` and ``name``, holding one
        row per country-year that has a non-null value. If nothing was
        retrieved, an empty frame with the same three columns.
    """
    import urllib.request
    import json

    print(f"  Downloading WDI {indicator} ({name}) ...")
    all_rows = []
    chunk_size = 15

    for i in range(0, len(countries), chunk_size):
        chunk = countries[i:i + chunk_size]
        country_str = ';'.join(chunk)
        url = (f"https://api.worldbank.org/v2/country/{country_str}/"
               f"indicator/{indicator}?format=json&per_page=10000&date=1970:2024")

        for attempt in range(max_retries):
            try:
                req = urllib.request.Request(url)
                req.add_header('User-Agent', 'Mozilla/5.0')
                with urllib.request.urlopen(req, timeout=30) as resp:
                    data = json.loads(resp.read().decode())

                # A request the API rejects comes back as a 'message' entry
                # instead of [metadata, observations]. Retrying the same URL
                # will not help, so report it and move on to the next chunk.
                if (isinstance(data, list) and data
                        and isinstance(data[0], dict) and 'message' in data[0]):
                    print(f"    API error for chunk {i//chunk_size}: "
                          f"{data[0]['message']}")
                    break

                # Collect into a per-attempt list and commit only once the
                # whole chunk parsed, so a retry never double-appends rows.
                chunk_rows = []
                if len(data) >= 2 and data[1]:
                    for obs in data[1]:
                        if obs.get('value') is not None:
                            chunk_rows.append({
                                'iso3': obs['countryiso3code'],
                                'year': int(obs['date']),
                                name: float(obs['value']),
                            })
                all_rows.extend(chunk_rows)
                break
            except Exception as e:
                # Deliberately broad: the download is best-effort and any
                # network/parse failure should cost at most this chunk.
                if attempt == max_retries - 1:
                    print(f"    Failed chunk {i//chunk_size}: {e}")
                else:
                    # Back off only when another attempt will follow.
                    time.sleep(1.5)

        # Small pause between chunks to go easy on the API.
        time.sleep(0.5)

    if all_rows:
        df = pd.DataFrame(all_rows)
        df = df.drop_duplicates(subset=['iso3', 'year'])
        print(f"    {name}: {df['iso3'].nunique()} countries, {len(df)} obs")
        return df
    print(f"    {name}: no data retrieved")
    return pd.DataFrame(columns=['iso3', 'year', name])


def _load_wdi_savings(countries):
    """Return the WDI savings/consumption frame, from the raw cache if present.

    On a cache miss every indicator is downloaded and outer-merged on
    (iso3, year). The result is written to the cache only when every
    indicator returned data; otherwise a partial download would be frozen
    in the cache and silently reused by all later runs.
    """
    raw_cache = RAW_DIR / "wdi_savings.csv"
    if raw_cache.exists():
        print("  Loading cached WDI savings data ...")
        return pd.read_csv(raw_cache)

    indicators = {
        'NY.GNS.ICTR.ZS': 'gross_savings_wdi',           # Gross savings (% GNI)
        'NY.ADJ.NNAT.GN.ZS': 'adj_net_national_savings',  # Adjusted net national savings (% GNI)
        'NE.CON.GOVT.ZS': 'govt_consumption_gdp',         # Government final consumption (% GDP)
        'NE.CON.PRVT.ZS': 'private_consumption_gdp',      # Household final consumption (% GDP)
        'NE.CON.TOTL.ZS': 'total_consumption_gdp',        # Final consumption (% GDP)
        'GC.XPN.TOTL.GD.ZS': 'govt_expenditure_wdi',      # Expense (% GDP)
        'GC.REV.XGRT.GD.ZS': 'govt_revenue_wdi',          # Revenue excl grants (% GDP)
    }

    frames = []
    missing = []
    for indicator, name in indicators.items():
        idf = download_wdi_indicator(indicator, name, countries)
        if len(idf) > 0:
            frames.append(idf)
        else:
            missing.append(name)

    if not frames:
        return pd.DataFrame(columns=['iso3', 'year'])

    sav_df = frames[0]
    for idf in frames[1:]:
        sav_df = sav_df.merge(idf, on=['iso3', 'year'], how='outer')

    if missing:
        print(f"  WARNING: no data for {', '.join(missing)}; "
              f"result not cached so the next run retries the download")
    else:
        sav_df.to_csv(raw_cache, index=False)
    return sav_df


def _merge_missing_columns(df, other, columns):
    """Left-merge onto ``df`` those ``columns`` of ``other`` that ``df`` lacks.

    Returns ``(merged_df, added_columns)``; ``df`` is returned unchanged when
    there is nothing to add. ``other`` is de-duplicated on (iso3, year) first
    so the merge cannot multiply rows of ``df``.
    """
    keys = ['iso3', 'year']
    existing = set(df.columns)
    added = [c for c in columns
             if c in other.columns and c not in existing and c not in keys]
    if not added:
        return df, added
    right = other[keys + added].drop_duplicates(subset=keys)
    return df.merge(right, on=keys, how='left'), added


def _add_sectoral_savings(df):
    """Add government and private saving columns where their inputs exist."""
    # Government saving = revenue - expenditure (fiscal panel preferred, WDI fallback)
    if 'govt_revenue_gdp' in df.columns and 'govt_expenditure_gdp' in df.columns:
        df['govt_saving_gdp'] = df['govt_revenue_gdp'] - df['govt_expenditure_gdp']
        print(f"  govt_saving_gdp: {df['govt_saving_gdp'].notna().sum():,} obs")
    elif 'govt_revenue_wdi' in df.columns and 'govt_expenditure_wdi' in df.columns:
        df['govt_saving_gdp'] = df['govt_revenue_wdi'] - df['govt_expenditure_wdi']
        print(f"  govt_saving_gdp (WDI): {df['govt_saving_gdp'].notna().sum():,} obs")

    # Private saving = gross_national_savings - govt_saving
    if 'gross_national_savings_gdp' in df.columns and 'govt_saving_gdp' in df.columns:
        df['private_saving_gdp'] = df['gross_national_savings_gdp'] - df['govt_saving_gdp']
        print(f"  private_saving_gdp: {df['private_saving_gdp'].notna().sum():,} obs")

    # Only the household consumption share is carried forward here. No
    # household or corporate saving series is constructed in this phase:
    # a true household/corporate split needs sectoral accounts, and any
    # residual built from the aggregates above would be inherently noisy.
    if 'private_consumption_gdp' in df.columns:
        df['household_consumption_share'] = df['private_consumption_gdp']
    return df


def _add_demographic_terms(df):
    """Add openness interactions, 5-row lags and first differences of Z_1..Z_3.

    Returns the frame sorted by (iso3, year).
    """
    demo_vars = [zv for zv in ('Z_1', 'Z_2', 'Z_3') if zv in df.columns]

    for zv in demo_vars:
        if 'kaopen' in df.columns:
            df[f'{zv}_x_kaopen'] = df[zv] * df['kaopen']
        if 'trade_openness' in df.columns:
            df[f'{zv}_x_trade'] = df[zv] * df['trade_openness']

    # NOTE(review): shift(5)/diff() step over rows within each country, so
    # they equal a 5-year lag / annual change only if every country has
    # exactly one row per consecutive year - confirm for full_panel.csv.
    df = df.sort_values(['iso3', 'year'])
    for zv in demo_vars:
        df[f'{zv}_lag5'] = df.groupby('iso3')[zv].shift(5)
        df[f'd_{zv}'] = df.groupby('iso3')[zv].diff()
    return df


def _income_terciles(x):
    """Label one year's gdp_pc_ppp values low/mid/high; NaN if terciles fail."""
    try:
        return pd.qcut(x, 3, labels=['low', 'mid', 'high'], duplicates='drop')
    except (ValueError, IndexError):
        return pd.Series(np.nan, index=x.index)


def _write_summary_table(df):
    """Write N/mean/SD/min/max of the key variables as a markdown table."""
    key_vars = ['gross_national_savings_gdp', 'gross_savings_gdp',
                'gross_investment_gdp', 'savings_investment_gap',
                'govt_saving_gdp', 'private_saving_gdp',
                'private_consumption_gdp', 'govt_consumption_gdp',
                'adj_net_national_savings',
                'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'ca_gdp', 'fiscal_bal_gdp', 'kaopen', 'log_gdp_pc']
    key_vars = [v for v in key_vars if v in df.columns]

    md = ["# Summary Statistics — Sectoral Savings Panel\n"]
    md.append("| Variable | N | Mean | SD | Min | Max |")
    md.append("|---|---|---|---|---|---|")
    for v in key_vars:
        s = df[v].dropna()
        if len(s) > 0:
            md.append(f"| {v} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} "
                      f"| {s.min():.3f} | {s.max():.3f} |")
    out = TABLES_DIR / "summary_statistics.md"
    # The title contains a non-ASCII dash, so pin the encoding instead of
    # depending on the platform default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def main():
    """Assemble the sectoral savings panel and write it to disk.

    Steps: load full_panel.csv (years <= 2024); report coverage of the
    savings/investment variables already present; merge WDI savings and
    consumption indicators (cached under RAW_DIR); merge fiscal variables
    from the fiscal panel if that file exists; derive government and private
    saving; add interaction, lag and difference terms for Z_1..Z_3; add the
    OECD dummy and per-year income terciles; restrict to 1990-2024; write
    summary_statistics.md and sectoral_panel.csv.

    Returns:
        The assembled panel as a DataFrame (the same data written to
        sectoral_panel.csv).
    """
    print("=" * 70)
    print("PHASE 1: Data Assembly — Sectoral Savings Decomposition")
    print("=" * 70)

    # ── [1] Load full_panel.csv ──
    print("\n[1] Loading full_panel.csv ...")
    full_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    df = pd.read_csv(full_path)
    df = df[df['year'] <= 2024].copy()
    countries = sorted(df['iso3'].unique())
    print(f"  Full panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    # ── [2] Check existing savings/investment vars ──
    print("\n[2] Existing savings/investment variables ...")
    for v in ['gross_savings_gdp', 'gross_national_savings_gdp',
              'gross_investment_gdp', 'gross_fixed_investment_gdp',
              'savings_investment_gap', 'fiscal_bal_gdp']:
        if v in df.columns:
            n = df[v].notna().sum()
            nc = df.loc[df[v].notna(), 'iso3'].nunique()
            print(f"  {v}: {n:,} obs, {nc} countries")

    # ── [3] Download additional WDI savings indicators ──
    print("\n[3] Downloading WDI savings/consumption indicators ...")
    sav_df = _load_wdi_savings(countries)
    if len(sav_df) > 0:
        # Only columns the panel does not already have are brought in.
        df, new_cols = _merge_missing_columns(df, sav_df, sav_df.columns)
        for col in new_cols:
            n = df[col].notna().sum()
            print(f"  {col}: {n:,} obs")

    # ── [4] Load fiscal panel for government savings ──
    print("\n[4] Loading fiscal panel ...")
    fiscal_path = FISCAL_DIR / "data" / "processed" / "fiscal_panel.csv"
    if fiscal_path.exists():
        fisc = pd.read_csv(fiscal_path)
        fiscal_vars = ['govt_revenue_gdp', 'govt_expenditure_gdp',
                       'govt_debt_gdp', 'primary_bal_gdp']
        df, _ = _merge_missing_columns(df, fisc, fiscal_vars)
    else:
        # Not fatal: step [5] falls back to the WDI revenue/expense series.
        print(f"  Fiscal panel not found, skipping: {fiscal_path}")

    # ── [5] Construct sectoral savings decomposition ──
    print("\n[5] Constructing sectoral decomposition ...")
    df = _add_sectoral_savings(df)

    # ── [6] Interaction terms ──
    print("\n[6] Interaction terms ...")
    df = _add_demographic_terms(df)

    # ── [7] Regime variables ──
    df['oecd'] = df['iso3'].isin(OECD_38).astype(int)
    if 'gdp_pc_ppp' in df.columns:
        # Terciles are computed within each year, over all countries present.
        df['income_group'] = df.groupby('year')['gdp_pc_ppp'].transform(_income_terciles)

    # ── [8] Restrict and summarize ──
    # The year restriction comes after step [6] so that lags for the early
    # 1990s can draw on pre-1990 rows.
    print("\n[8] Panel summary ...")
    df = df[(df['year'] >= 1990) & (df['year'] <= 2024)].copy()

    n_total = len(df)
    n_countries = df['iso3'].nunique()
    print(f"  Total panel: {n_countries} countries, {n_total:,} obs")

    # ── [9] Summary statistics ──
    _write_summary_table(df)

    # ── [10] Save ──
    print("\n[10] Saving sectoral_panel.csv ...")
    df.to_csv(PROCESSED_DIR / "sectoral_panel.csv", index=False)
    print(f"  Saved: {PROCESSED_DIR / 'sectoral_panel.csv'}")
    print(f"  Shape: {df.shape[0]:,} obs x {df.shape[1]} columns")

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return df


# Script entry point. The assembled panel is bound to a module-level `df`
# so it stays available for inspection after the run.
if __name__ == "__main__":
    df = main()
